import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Flatten


class ResidualBlock(nn.Module):
    """Shape-preserving residual block computing ``relu(conv2(conv1(x)) + x)``.

    ``conv1`` is a 2x2 convolution with ``padding=1``, followed by BatchNorm
    and an in-place ReLU; it grows both spatial dims by one. ``conv2`` is an
    unpadded 2x2 convolution that shrinks them back by one, so its output has
    the same shape as the input and can be added to it:
    ``(N, C, H, W) -> (N, C, H + 1, W + 1) -> (N, C, H, W)``.

    Args:
        channel: Number of input channels ``C``; the output has the same count.
    """

    def __init__(self, channel):
        super().__init__()
        self.channel = channel
        expand_layers = [
            nn.Conv2d(
                in_channels=channel,
                out_channels=channel,
                kernel_size=(2, 2),
                stride=(1, 1),
                padding=1,
            ),
            nn.BatchNorm2d(channel),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*expand_layers)
        # No BatchNorm follows this conv (unlike conv1); the sum with the
        # skip path is activated directly in forward().
        self.conv2 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=(2, 2), stride=(1, 1)),
        )

    def forward(self, x):
        """Apply both convolutions, add the identity shortcut, then ReLU."""
        residual = self.conv2(self.conv1(x))
        return F.relu(residual + x)


class LResNet(nn.Module):
    """Small residual CNN classifier (residual network).

    Pipeline: 3x3 conv -> BatchNorm -> ReLU -> max-pool, one ``ResidualBlock``
    on 32 channels, a 2x2 conv (padding=1) down to a single channel ->
    BatchNorm -> ReLU -> max-pool, then flatten and a linear head.

    The linear head's ``in_features`` (36) is hard-wired to a
    ``(N, 1, 69, 6)`` input. Other spatial sizes produce a different
    flattened length and fail at ``self.fc``.

    Args:
        num_classes: Number of output logits. Defaults to 2, which matches the
            original fixed head.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Shape comments below assume a (N, 1, 69, 6) input.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)),  # (N, 1, 69, 6) -> (N, 32, 67, 4)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # ceil_mode keeps the odd trailing row: 67 -> 34 rather than 33.
            nn.MaxPool2d(2, ceil_mode=True),  # (N, 32, 67, 4) -> (N, 32, 34, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(2, 2), padding=1),  # (N, 32, 34, 2) -> (N, 1, 35, 3)
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.MaxPool2d(2, ceil_mode=True),  # (N, 1, 35, 3) -> (N, 1, 18, 2)
        )
        self.reslayer1 = ResidualBlock(32)
        # NOTE(review): reslayer2 is never called in forward(). It is kept so
        # existing state_dicts/checkpoints and any `model.reslayer2` references
        # keep working, but its parameters appear in parameters() and never
        # receive gradients.
        self.reslayer2 = ResidualBlock(16)
        self.flatten = Flatten()
        # 36 = 1 channel * 18 * 2, the conv2 output size for a (N, 1, 69, 6) input.
        self.fc = nn.Linear(36, num_classes)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)`` for a batch ``x``."""
        out = self.conv1(x)      # -> (N, 32, 34, 2)
        out = self.reslayer1(out)  # shape preserved: (N, 32, 34, 2)
        out = self.conv2(out)    # -> (N, 1, 18, 2)
        # A single flatten suffices; the previous extra .view() was redundant.
        out = self.flatten(out)  # -> (N, 36)
        return self.fc(out)


if __name__ == '__main__':
    # Smoke test: push a dummy all-ones batch through the network and report
    # the shape of the resulting logits.
    model = LResNet()
    dummy_batch = torch.ones((10, 1, 69, 6))
    logits = model(dummy_batch)
    print(logits.shape)

